#!/usr/bin/python
# -*- coding: utf-8 -*-
import json
import logging
import os
import threading
import time
from urllib.parse import unquote, urlsplit

import requests

from singleton import singleton
from thread_result import ThreadResult
from threadpool import ThreadPool
from valid_param import valid_param, multi_type, null_type


class Downloader(object):
    '''
    Multi-threaded HTTP file downloader with resume support.

    The remote file is split into DOWNLOAD_PART_SIZE byte parts that a thread pool fetches
    with HTTP Range requests and writes into a "<name>.download" temp file. Finished parts
    are recorded in "info/<name>.info" so an interrupted download can continue later; the
    temp file is moved to the final name once every part is done.
    '''

    # size in bytes of each range request; may be overridden on a subclass or an instance
    DOWNLOAD_PART_SIZE = 1024 * 1024

    # characters that are invalid in Windows paths / file names; they are replaced by a space
    _DIR_UNSAFE_CHARS = str.maketrans(dict.fromkeys('<>|:"*?', ' '))
    _NAME_UNSAFE_CHARS = str.maketrans(dict.fromkeys('<>|:"*?/\\\0', ' '))

    @valid_param(session=null_type(requests.Session))
    def __init__(self, session=None) -> None:
        '''
        :param session: requests session used for every HTTP request; if None, a new session
                        with browser-like default headers is created
        '''
        super().__init__()
        # HTTP session shared by all download threads
        self._session = session
        if self._session is None:
            self._session = requests.Session()
            # update() keeps requests' case-insensitive header mapping; all of requests'
            # own default headers are overridden by the keys below anyway
            self._session.headers.update({
                'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
                              'Chrome/57.0.2987.133 Safari/537.36',
                'Connection': 'Keep-alive',
                'Accept-Ranges': 'bytes',
                'Accept-Language': 'zh-CN',
                'Accept-Encoding': 'gzip, deflate, sdch',
                'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8'
            })

    @valid_param(url=(str, '0 != len(x)'), file_dir=null_type(str), file_name=null_type(str),
                 forced=bool, retry_count=int, thread_num=int)
    def download(self, url, file_dir='', file_name='', forced=False, resume=True, retry_count=0, thread_num=5):
        '''
        download file
        :param url: url address
        :param file_dir: file save dir, if it is empty(None or empty string), the current path is used
        :param file_name: file name, if it is empty(None or empty string), the last path segment of url is used
        :param forced: whether to force the download even if the file already exists, default is False
        :param resume: whether to continue from the previous location to download, default is True
        :param retry_count: number of extra attempts after a failed one, default is 0
                            (negative values are treated as 0)
        :param thread_num: the number of download threads; values <= 0 mean 5. default is 5
        :return: True if the file is available at its target path, otherwise False
        '''
        # a negative count would otherwise skip downloading entirely
        retry_count = max(retry_count, 0)
        for attempt in range(retry_count + 1):
            if self._download(url, file_dir, file_name, forced, resume, thread_num):
                return True
            if attempt != retry_count:
                logging.info('retry-%s download %s.' % (attempt + 1, url))
        return False

    def _download(self, url, file_dir, file_name, forced, resume, thread_num):
        '''
        Perform one download attempt; see download() for the parameters.
        :return: True if the complete file is at its final path, otherwise False
        '''
        file_dir = self._format_directory(file_dir)
        file_name = self._format_file_name(file_name, url)
        file_path = file_dir + file_name
        info_dir = file_dir + 'info/'
        info_path = info_dir + file_name + '.info'
        temp_path = file_dir + file_name + '.download'

        # the file already exists and a forced download was not requested
        if os.path.exists(file_path) and not forced:
            logging.info('file [%s] already exists, no download required' % file_path)
            return True

        file_size = self._get_remote_size(url)
        if file_size is None:
            return False
        if file_size <= 0:
            logging.error('url [%s] does not exist.' % url)
            return False
        logging.debug('url [%s] size is %s' % (url, self._conversion_size_unit(file_size)))

        info_dict = self._load_info(url, file_size, info_path, temp_path)
        info_dict['url'] = url
        info_dict['path'] = file_path
        info_dict['name'] = file_name
        info_dict['size'] = file_size
        info_dict['time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
        if not info_dict.get('part'):
            info_dict['part'] = {}

        logging.info('ready to download the file url [%s] to the path [%s]. file size is %s' %
                     (url, file_path, self._conversion_size_unit(file_size)))

        # info_dir lives inside file_dir, so this creates both directories
        os.makedirs(info_dir, exist_ok=True)

        # pre-allocate the temp file so every part can be written at its own offset
        if not os.path.exists(temp_path):
            with open(temp_path, 'wb') as fp:
                fp.truncate(file_size)

        start_time = time.perf_counter()
        lock = threading.Lock()
        if thread_num <= 0:
            thread_num = 5
        pool = ThreadPool(thread_num)
        part_size = self.DOWNLOAD_PART_SIZE
        task_num = (file_size + part_size - 1) // part_size
        part_keys = []
        # 'rb+' rather than an append mode: in append mode every write() goes to the end
        # of the file regardless of the current seek position
        with open(temp_path, 'rb+') as fp:
            for i in range(task_num):
                start = part_size * i
                # HTTP byte ranges are inclusive: a part covers start..end, and the last
                # part stops at the final byte (file_size - 1)
                end = min(start + part_size, file_size) - 1
                key = '%s-%s' % (start, end)
                part_keys.append(key)
                # parts finished by an earlier attempt are skipped when resuming
                if resume and info_dict['part'].get(key):
                    continue
                pool.apply_async(self._download_part, (url, start, end, fp, info_dict, lock))
            pool.close()
            pool.wait()

        # every submitted part must have succeeded and every part must be recorded as done;
        # when nothing had to be submitted, the temp file is already complete
        task_results = [task.get_result() for task in pool.task_list()]
        result = all(task_results) and all(info_dict['part'].get(key) for key in part_keys)

        info_dict['result'] = result
        with open(info_path, 'w') as info_fp:
            json.dump(info_dict, info_fp)

        if result:
            # os.replace (unlike os.rename) also overwrites an existing target on Windows,
            # which is the case for forced downloads
            os.replace(temp_path, file_path)
            logging.info('file [%s] size [%s]. download complete. time: %0.3f s' %
                         (file_path, self._conversion_size_unit(file_size), time.perf_counter() - start_time))
            return True
        logging.error('file [%s] download failure.' % file_path)
        return False

    def _get_remote_size(self, url):
        '''
        Ask the server for the size of the resource behind url with a HEAD request.
        :param url: url
        :return: value of the content-length header as int, or None if it cannot be determined
        '''
        try:
            # requests does not follow redirects for HEAD by default; without this the
            # content-length of a 3xx response would be taken as the file size.
            # 'identity' so the size refers to the raw bytes that the ranges address.
            res = self._session.head(url, allow_redirects=True, headers={'Accept-Encoding': 'identity'})
            res.raise_for_status()
            return int(res.headers.get('content-length'))
        except Exception as e:
            logging.error('Exception %s.' % e)
            logging.error('url [%s] does not support download.' % url)
            return None

    def _load_info(self, url, file_size, info_path, temp_path):
        '''
        Load the record of a previous attempt, or start a fresh one.
        The record is reused only when both the info file and the temp file exist, the info
        file is readable, and it describes the same url and size. Otherwise any leftover info
        and temp files are removed so the download starts from scratch.
        :param url: url
        :param file_size: current remote size in bytes
        :param info_path: path of the info (resume record) file
        :param temp_path: path of the partially downloaded temp file
        :return: info dict (empty when starting fresh)
        '''
        info_dict = None
        if os.path.exists(info_path) and os.path.exists(temp_path):
            try:
                with open(info_path, 'r') as info_fp:
                    info_dict = json.load(info_fp)
            except (OSError, ValueError) as e:
                logging.warning('info file [%s] cannot be read, download restarts. %s' % (info_path, e))
            # a record for another url or a changed size would mark the wrong bytes as done
            if not isinstance(info_dict, dict) or info_dict.get('url') != url \
                    or info_dict.get('size') != file_size:
                info_dict = None
        if info_dict is None:
            for path in (info_path, temp_path):
                if os.path.exists(path):
                    os.remove(path)
            info_dict = {}
        return info_dict

    @valid_param(url=(str, '0 != len(x)'), start=int, end=int)
    def _download_part(self, url, start, end, fp, info_dict, lock):
        '''
        download file part
        :param url: url
        :param start: offset of the first byte of the part
        :param end: offset of the last byte of the part (inclusive, as in HTTP ranges)
        :param fp: file handle of the temp file, opened for binary read/write
        :param info_dict: info dict; the part is marked as done in info_dict['part'] on success
        :param lock: threading lock serializing seek/write on fp and updates of info_dict
        :return: True on success, otherwise False
        '''
        part_headers = {
            'Range': 'bytes=%d-%d' % (start, end),
            # ranges must address the raw bytes, not a compressed representation
            'Accept-Encoding': 'identity',
        }
        try:
            logging.debug('start download url [%s] part [%s - %s].' % (url,
                                                                       self._conversion_size_unit(start),
                                                                       self._conversion_size_unit(end)))
            r = self._session.get(url, headers=part_headers, timeout=30)
            # never write an error page into the file
            r.raise_for_status()
            content = r.content
            # a server that ignores Range answers 200 with the whole file; that (or a short
            # body) must not be written at this offset
            expected = end - start + 1
            if len(content) != expected:
                logging.error('download url [%s] part [%s - %s] returned %d bytes, expected %d.' %
                              (url, self._conversion_size_unit(start), self._conversion_size_unit(end),
                               len(content), expected))
                return False
            if lock:
                lock.acquire()
            try:
                fp.seek(start)
                fp.write(content)
                fp.flush()
                info_dict['part']['%s-%s' % (start, end)] = True
            finally:
                # always release, otherwise a failed write would deadlock the other threads
                if lock:
                    lock.release()
            logging.debug('download url [%s] part [%s - %s] complete.' % (url,
                                                                          self._conversion_size_unit(start),
                                                                          self._conversion_size_unit(end)))
            return True
        except Exception as e:
            logging.error('Exception %s.' % e)
            logging.error('download url [%s] part [%s - %s] failure.' % (url,
                                                                         self._conversion_size_unit(start),
                                                                         self._conversion_size_unit(end)))
            return False

    @valid_param(file_name=multi_type(str, None), url=(str, '0 != len(x)'))
    def _format_file_name(self, file_name, url):
        '''
        format file name
        :param file_name: desired file name; if empty (None or ''), the last path segment of
                          url is used (query string and fragment excluded, percent-escapes
                          decoded), falling back to 'index.html' when that segment is empty
        :param url: url
        :return: file name with characters invalid on Windows replaced by spaces
        '''
        if not file_name:
            file_name = unquote(urlsplit(url).path.split('/')[-1]) or 'index.html'
        return file_name.translate(self._NAME_UNSAFE_CHARS)

    @valid_param(directory=multi_type(str, None))
    def _format_directory(self, directory):
        '''
        format directory
        :param directory: directory; if empty (None or ''), the current directory './' is used
        :return: directory ending with a path separator, with characters invalid on Windows
                 replaced by spaces (a drive prefix such as 'C:' is kept intact)
        '''
        if not directory:
            return './'
        if directory[-1] not in ('\\', '/'):
            directory += '/'
        # splitdrive only recognizes drives on Windows; elsewhere drive is ''
        drive, rest = os.path.splitdrive(directory)
        return drive + rest.translate(self._DIR_UNSAFE_CHARS)

    @valid_param(size=int)
    def _conversion_size_unit(self, size):
        '''
        Format a byte count for log messages.
        :param size: size in bytes
        :return: string such as '512B', '1.500KB', '2.000MB' or '3.250GB'
        :raises TypeError: if size is not an int
        '''
        if not isinstance(size, int):
            raise TypeError('size type not is int type.')
        if size < 1024:
            return '%dB' % size
        if size < 1024 ** 2:
            return '%0.3fKB' % (size / 1024)
        if size < 1024 ** 3:
            return '%0.3fMB' % (size / 1024 ** 2)
        return '%0.3fGB' % (size / 1024 ** 3)